{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![MIT Deep Learning](https://deeplearning.mit.edu/files/images/github/mit_deep_learning.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"center\">\n",
    "  <td align=\"center\"><a target=\"_blank\" href=\"https://deeplearning.mit.edu\">\n",
    "        <img src=\"https://deeplearning.mit.edu/files/images/github/icon_mit.png\" style=\"padding-bottom:5px;\" />\n",
    "      Visit MIT Deep Learning</a></td>\n",
    "  <td align=\"center\"><a target=\"_blank\" href=\"http://colab.research.google.com/github/lexfridman/mit-deep-learning/blob/master/tutorial_gans/tutorial_gans.ipynb\">\n",
    "        <img src=\"https://deeplearning.mit.edu/files/images/github/icon_google_colab.png\" style=\"padding-bottom:5px;\" />Run in Google Colab</a></td>\n",
    "  <td align=\"center\"><a target=\"_blank\" href=\"https://github.com/lexfridman/mit-deep-learning/blob/master/tutorial_gans/tutorial_gans.ipynb\">\n",
    "        <img src=\"https://deeplearning.mit.edu/files/images/github/icon_github.png\" style=\"padding-bottom:5px;\"  />View Source on GitHub</a></td>\n",
    "  <td align=\"center\"><a target=\"_blank\" align=\"center\" href=\"https://www.youtube.com/watch?v=O5xeyoRL95U&list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf\">\n",
    "        <img src=\"https://deeplearning.mit.edu/files/images/github/icon_youtube.png\" style=\"padding-bottom:5px;\" />Watch YouTube Videos</a></td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Adversarial Networks (GANs)\n",
    "\n",
    "This tutorial accompanies lectures of the [MIT Deep Learning](https://deeplearning.mit.edu) series. Acknowledgement to amazing people involved is provided throughout the tutorial and at the end. Introductory lectures on GANs include the following (with more coming soon):\n",
    "\n",
    "<table>\n",
    "  <td align=\"center\" style=\"text-align: center;\">    \n",
    "    <a target=\"_blank\" href=\"https://www.youtube.com/watch?list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf&v=O5xeyoRL95U\">\n",
    "        <img src=\"https://i.imgur.com/FfQVV8q.png\" style=\"padding-bottom:5px;\" />\n",
    "        (Lecture) Deep Learning Basics: Intro and Overview\n",
    "    </a>\n",
    "  </td>\n",
    "  <td align=\"center\" style=\"text-align: center;\">\n",
    "      <a target=\"_blank\" href=\"https://www.youtube.com/watch?list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf&v=53YvP6gdD7U\">\n",
    "        <img src=\"https://i.imgur.com/vbNjF3N.png\" style=\"padding-bottom:5px;\" />\n",
    "          (Lecture) Deep Learning State of the Art 2019\n",
    "      </a>\n",
    "  </td>\n",
    "</table>\n",
    "\n",
    "Generative Adversarial Networks (GANs) are a framework for training networks optimized for generating new realistic samples from a particular representation. In its simplest form, the training process involves two networks. One network, called the generator, generates new data instances, trying to fool the other network, the discriminator, that classifies images as real or fake. This original form is illustrated as follows (where #6 refers to one of 7 architectures described in the [Deep Learning Basics tutorial](https://github.com/lexfridman/mit-deep-learning/blob/master/tutorial_deep_learning_basics/deep_learning_basics.ipynb)):\n",
    "\n",
    "<img src=\"https://i.imgur.com/LweaD1s.png\" width=\"600px\">\n",
    "\n",
    "There are broadly 3 categories of GANs:\n",
    "\n",
    "1. **Unsupervised GANs**: The generator network takes random noise as input and produces a photo-realistic image that appears very similar to images that appear in the training dataset. Examples include the [original version of GAN](https://arxiv.org/abs/1406.2661), [DC-GAN](https://arxiv.org/abs/1511.06434), [pg-GAN](https://arxiv.org/abs/1710.10196), etc.\n",
    "3. **Style-Transfer GANs** - Translate images from one domain to another (e.g., from horse to zebra, from sketch to colored images). Examples include [CycleGAN](https://junyanz.github.io/CycleGAN/) and [pix2pix](https://phillipi.github.io/pix2pix/).\n",
    "2. **Conditional GANs** - Jointly learn on features along with images to generate images conditioned on those features (e.g., generating an instance of a particular class). Examples includes [Conditional GAN](https://arxiv.org/abs/1411.1784), [AC-GAN](https://arxiv.org/abs/1610.09585), [Stack-GAN](https://github.com/hanzhanggit/StackGAN), and [BigGAN](https://arxiv.org/abs/1809.11096).\n",
    "\n",
    "First, we illustrate BigGAN, a state-of-the-art conditional GAN from DeepMind. This illustration is based on the [BigGAN TF Hub Demo](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/biggan_generation_with_tf_hub.ipynb) and the BigGAN generators on [TF Hub](https://tfhub.dev/deepmind/biggan-256). See the [BigGAN paper on arXiv](https://arxiv.org/abs/1809.11096) [1] for more information about these models.\n",
    "\n",
    "We'll be adding more parts to this tutorial as additional lectures come out."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1: BigGAN\n",
    "\n",
    "We recommend that you run this this notebook in the cloud on Google Colab. If you have not done so yet, consider following the setup steps in the [Deep Learning Basics tutorial](https://github.com/lexfridman/mit-deep-learning) and reading the [Deep Learning Basics: Introduction and Overview with TensorFlow](https://medium.com/tensorflow/mit-deep-learning-basics-introduction-and-overview-with-tensorflow-355bcd26baf0) blog post."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/device:GPU:0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# basics\n",
    "import io\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "# deep learning\n",
    "from scipy.stats import truncnorm\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "\n",
    "# visualization\n",
    "from IPython.core.display import HTML\n",
    "#!pip install imageio\n",
    "import imageio\n",
    "import base64\n",
    "\n",
    "# check that tensorflow GPU is enabled\n",
    "tf.test.gpu_device_name() # returns empty string if using CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load BigGAN generator module from TF Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading BigGAN module from: https://tfhub.dev/deepmind/biggan-512/1\n",
      "INFO:tensorflow:Using /tmp/tfhub_modules to cache modules.\n",
      "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n"
     ]
    }
   ],
   "source": [
    "# comment out the TF Hub module path you would like to use\n",
    "# module_path = 'https://tfhub.dev/deepmind/biggan-128/1'  # 128x128 BigGAN\n",
    "# module_path = 'https://tfhub.dev/deepmind/biggan-256/1'  # 256x256 BigGAN\n",
    "module_path = 'https://tfhub.dev/deepmind/biggan-512/1'  # 512x512 BigGAN\n",
    "\n",
    "tf.reset_default_graph()\n",
    "print('Loading BigGAN module from:', module_path)\n",
    "module = hub.Module(module_path)\n",
    "inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)\n",
    "          for k, v in module.get_input_info_dict().items()}\n",
    "output = module(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Functions for Sampling and Interpolating the Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_z = inputs['z']\n",
    "input_y = inputs['y']\n",
    "input_trunc = inputs['truncation']\n",
    "\n",
    "dim_z = input_z.shape.as_list()[1]\n",
    "vocab_size = input_y.shape.as_list()[1]\n",
    "\n",
    "# sample truncated normal distribution based on seed and truncation parameter\n",
    "def truncated_z_sample(truncation=1., seed=None):\n",
    "    state = None if seed is None else np.random.RandomState(seed)\n",
    "    values = truncnorm.rvs(-2, 2, size=(1, dim_z), random_state=state)\n",
    "    return truncation * values\n",
    "\n",
    "# convert `index` value to a vector of all zeros except for a 1 at `index`\n",
    "def one_hot(index, vocab_size=vocab_size):\n",
    "    index = np.asarray(index)\n",
    "    if len(index.shape) == 0: # when it's a scale convert to a vector of size 1\n",
    "        index = np.asarray([index])\n",
    "    assert len(index.shape) == 1\n",
    "    num = index.shape[0]\n",
    "    output = np.zeros((num, vocab_size), dtype=np.float32)\n",
    "    output[np.arange(num), index] = 1\n",
    "    return output\n",
    "\n",
    "def one_hot_if_needed(label, vocab_size=vocab_size):\n",
    "    label = np.asarray(label)\n",
    "    if len(label.shape) <= 1:\n",
    "        label = one_hot(label, vocab_size)\n",
    "    assert len(label.shape) == 2\n",
    "    return label\n",
    "\n",
    "# using vectors of noise seeds and category labels, generate images\n",
    "def sample(sess, noise, label, truncation=1., batch_size=8, vocab_size=vocab_size):\n",
    "    noise = np.asarray(noise)\n",
    "    label = np.asarray(label)\n",
    "    num = noise.shape[0]\n",
    "    if len(label.shape) == 0:\n",
    "        label = np.asarray([label] * num)\n",
    "    if label.shape[0] != num:\n",
    "        raise ValueError('Got # noise samples ({}) != # label samples ({})'\n",
    "                         .format(noise.shape[0], label.shape[0]))\n",
    "    label = one_hot_if_needed(label, vocab_size)\n",
    "    ims = []\n",
    "    for batch_start in range(0, num, batch_size):\n",
    "        s = slice(batch_start, min(num, batch_start + batch_size))\n",
    "        feed_dict = {input_z: noise[s], input_y: label[s], input_trunc: truncation}\n",
    "        ims.append(sess.run(output, feed_dict=feed_dict))\n",
    "    ims = np.concatenate(ims, axis=0)\n",
    "    assert ims.shape[0] == num\n",
    "    ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)\n",
    "    ims = np.uint8(ims)\n",
    "    return ims\n",
    "\n",
    "def interpolate(a, b, num_interps):\n",
    "    alphas = np.linspace(0, 1, num_interps)\n",
    "    assert a.shape == b.shape, 'A and B must have the same shape to interpolate.'\n",
    "    return np.array([(1-x)*a + x*b for x in alphas])\n",
    "\n",
    "def interpolate_and_shape(a, b, steps):\n",
    "    interps = interpolate(a, b, steps)\n",
    "    return (interps.transpose(1, 0, *range(2, len(interps.shape))).reshape(steps, -1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a TensorFlow session and initialize variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "initializer = tf.global_variables_initializer()\n",
    "sess = tf.Session()\n",
    "sess.run(initializer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create video of interpolated BigGAN generator samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# category options: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a\n",
    "category = 947 # mushroom\n",
    "\n",
    "# important parameter that controls how much variation there is\n",
    "truncation = 0.2 # reasonable range: [0.02, 1]\n",
    "\n",
    "seed_count = 10\n",
    "clip_secs = 36\n",
    "\n",
    "seed_step = int(100 / seed_count)\n",
    "interp_frames = int(clip_secs * 30 / seed_count)  # interpolation frames\n",
    "\n",
    "cat1 = category\n",
    "cat2 = category\n",
    "all_imgs = []\n",
    "\n",
    "for i in range(seed_count):\n",
    "    seed1 = i * seed_step # good range for seed is [0, 100]\n",
    "    seed2 = ((i+1) % seed_count) * seed_step\n",
    "    \n",
    "    z1, z2 = [truncated_z_sample(truncation, seed) for seed in [seed1, seed2]]\n",
    "    y1, y2 = [one_hot([category]) for category in [cat1, cat2]]\n",
    "\n",
    "    z_interp = interpolate_and_shape(z1, z2, interp_frames)\n",
    "    y_interp = interpolate_and_shape(y1, y2, interp_frames)\n",
    "\n",
    "    imgs = sample(sess, z_interp, y_interp, truncation=truncation)\n",
    "    \n",
    "    all_imgs.extend(imgs[:-1])\n",
    "\n",
    "# save the video for displaying in the next cell, this is way more space efficient than the gif animation\n",
    "imageio.mimsave('gan.mp4', all_imgs, fps=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay loop>\n",
       "  <source src=\"gan.mp4\" type=\"video/mp4\">\n",
       "</video>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%HTML\n",
    "<video autoplay loop>\n",
    "  <source src=\"gan.mp4\" type=\"video/mp4\">\n",
    "</video>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above code should generate a 512x512 video version of the following:\n",
    "\n",
    "![BigGAN mushroom](https://i.imgur.com/TA9uh1a.gif)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Acknowledgements\n",
    "\n",
    "The content of this tutorial is based on and inspired by the work of [TensorFlow team](https://www.tensorflow.org) (see their [Colab notebooks](https://www.tensorflow.org/tutorials/)), [Google DeepMind](https://deepmind.com/), our [MIT Human-Centered AI team](https://hcai.mit.edu), and individual pieces referenced in the [MIT Deep Learning](https://deeplearning.mit.edu) course slides.\n",
    "\n",
    "TF Colab and TF Hub content is copyrighted to The TensorFlow Authors (2018). Licensed under the Apache License, Version 2.0 (the \"License\"); http://www.apache.org/licenses/LICENSE-2.0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "colab": {
      "name": "tutorial_gans.ipynb",
      "version": "0.1.0",
      "provenance": []
    },
    "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
